Skip to content

feat(megatron): split-API train-step state machine on MegatronPolicyWorker#2683

Open
mehraakash wants to merge 7 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/split_train_mcore
Open

feat(megatron): split-API train-step state machine on MegatronPolicyWorker#2683
mehraakash wants to merge 7 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/split_train_mcore

Conversation

@mehraakash

Copy link
Copy Markdown

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:

  • begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn / gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and null model.config.grad_sync_func (saved for restore) so the PP scheduler's direct reduce dispatch cannot bypass no_sync.
  • train_microbatch(data): wrap one megatron_forward_backward invocation in with self.model.no_sync(): so mcore DDP hooks accumulate param.main_grad locally without dispatching the cross-DP reduce. Pass global_valid_seqs/toks=tensor(1.0) so the loss returns un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local mask sums + per-mb metrics + the total pipeline-microbatch count (for finish-time MoE aux-loss scaling).
  • finish_train_step: all_reduce mask sums to get true N (toks for TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call self.model.scale_gradients(1/N), then the one true cross-DP reduce via start_grad_sync + finish_grad_sync, optimizer.step (clips internally), restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
  • abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad, drop state. trainer_version unchanged.

Sync train() is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py covering the lifecycle and call-order invariants (no_sync wrap, grad_sync_func save/restore, mask-sum accumulation, N selection by loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so they run only in mcore-enabled CI containers.

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@mehraakash mehraakash requested review from a team as code owners June 4, 2026 06:09
@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026

Copy link
Copy Markdown

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@mehraakash

Copy link
Copy Markdown
Author

/ok to test

@mehraakash

Copy link
Copy Markdown
Author

/ok to test 3d24224

@mehraakash

Copy link
Copy Markdown
Author

/ok to test cb65400

@mehraakash

Copy link
Copy Markdown
Author

/ok to test 108ec17

mehraakash added a commit to mehraakash/RL that referenced this pull request Jun 8, 2026
Drives the begin/train_microbatch/finish split API in NVIDIA-NeMo#2683 and
group, per-group prepare_logprobs (when configured) -> advantage_pump
-> train_microbatch_from_meta (queued), one finish_train_step + one
clear_samples at end-of-step. Buffer capacity released per group, not
per step.

- StalenessSampler.select_one_group: picks one eligible prompt group;
  same predicate as select_indices, sort by (lag, indices[0]).
- SingleControllerConfig.target_prompt_groups_per_step: explicit per-
  step admission count; validated against min_prompt_groups_per_batch.
- _reap_in_flight_nonblocking: ray.wait(timeout=0) drain helper.
- DryRunTrainer: split-API stub with begin/microbatch/finish/abort
  invariants for dry-run tests.
- 7 streaming dry-run tests: arrival order, finish-time trainer_version
  tick, strict on-policy filter, long-tail overlap, abort idempotence,
  empty-step no-op, single clear_samples per step.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash

Copy link
Copy Markdown
Author

/ok to test 5753eb4

mehraakash and others added 6 commits June 8, 2026 15:29
…orker

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step
on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but
adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:
- begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn /
  gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and
  null model.config.grad_sync_func (saved for restore) so the PP scheduler's
  direct reduce dispatch cannot bypass no_sync.
- train_microbatch(data): wrap one ``megatron_forward_backward`` invocation
  in ``with self.model.no_sync():`` so mcore DDP hooks accumulate
  ``param.main_grad`` locally without dispatching the cross-DP reduce.
  Pass ``global_valid_seqs/toks=tensor(1.0)`` so the loss returns
  un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local
  mask sums + per-mb metrics + the total pipeline-microbatch count
  (for finish-time MoE aux-loss scaling).
- finish_train_step: all_reduce mask sums to get true N (toks for
  TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call
  self.model.scale_gradients(1/N), then the one true cross-DP reduce via
  start_grad_sync + finish_grad_sync, optimizer.step (clips internally),
  restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb
  metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
- abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad,
  drop state. ``trainer_version`` unchanged.

Sync ``train()`` is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py
covering the lifecycle and call-order invariants (no_sync wrap,
grad_sync_func save/restore, mask-sum accumulation, N selection by
loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so
they run only in mcore-enabled CI containers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Pre-existing zero-error file from NVIDIA-NeMo#2078 (Eagle3) that was never added
to the project-includes whitelist. Carrying the fix forward in this
PR to unblock the lint job.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
The file is introduced by NVIDIA-NeMo#2692 (DTensor PR), not by this branch.
Whitelisting it here causes pyrefly to fail with 'No Python files
matched pattern' since the file does not exist on mcore.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
cloudpickle traverses globals/closures of each method when serializing
the Ray actor class. With torch 2.11, 'config' in __code__.co_names
matches torch.distributed.config (a non-pickleable ConfigModuleInstance),
breaking actor creation with:
  TypeError: cannot pickle 'ConfigModuleInstance' object
  Could not serialize the actor class ...MegatronPolicyWorker

Same workaround as the existing sync train(): read 'config' via
getattr-by-string in begin/finish/abort_train_step.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash

Copy link
Copy Markdown
Author

/ok to test 3bc3244

test_megatron_split_state.py eagerly imports megatron_policy_worker
which transitively imports megatron.bridge. In non-mcore shards (Models,
Vllm, Sglang, Automodel_Policy), megatron.bridge isn't installed so
collection of this file fails, killing every other test in the shard.

pytest.importorskip stops collection cleanly when megatron.bridge is
not available. The pytest.mark.mcore filter still ensures these tests
only run in mcore shards.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash

Copy link
Copy Markdown
Author

/ok to test cd456b6

# state machine; this mixin just gates them on TQ-presharded data.

@wrap_with_nvtx_name("policy_worker/begin_train_step_presharded")
def begin_train_step_presharded(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

begin_train_step_presharded / finish_train_step_presharded / abort_train_step_presharded are pure pass-through — args 1:1 forwarded to the backend, no KVBatchMeta → BatchedDataDict translation (unlike train_microbatch_presharded, which does _fetch + _attach_or_repack_pack_metadata). The TQPolicy dispatch in #2700 could call run_all_workers_single_data("begin_train_step", ...) / "finish_train_step" / "abort_train_step" directly against the backend method names, dropping these three wrappers (~60 lines + the type: ignore[attr-defined] noise). The mixin still hosts train_microbatch_presharded where it actually earns its keep. wdyt?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L0 Run doctests and unit tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants